import sys M = 998244353; R = [1] def ntt(P): n = k = len(P); P = [*P]; Z = [0]*n while 2*len(R) < n: u = pow(3, M//(4*len(R)), M); R.extend([r*u%M for r in R]) # 3 is a primitive root of M while k > 1: for i in range(n//k): r = R[i] for j in range(i*k, i*k+k//2): z = r*P[j+k//2]; P[j+k//2] = (P[j]-z)%M; P[j] = (P[j]+z)%M k >>= 1 for i in range(1, n): Z[i] = Z[i//2]//2+(i&1)*n//2 return [P[r] for r in Z] def mult(p1, p2): m = len(p1)+len(p2)-1; n = 1 while n < m: n *= 2 p1 += [0]*(n-len(p1)); p2 += [0]*(n-len(p2)); ntt1 = ntt(p1); ntt2 = ntt(p2) z = pow(n, -1, M); return ntt([ntt1[-i]*ntt2[-i]%M*z%M for i in range(n)])[:m] n = int(input()); e = {0} for line in sys.stdin: n -= 1; e.add(int(line)) if n == 0: break m = int(input()); d = [*map(int, sys.stdin)] p = [0]*(max(e) + 1) for i in e: p[i] = 1 res = mult(p, p) print(sum(map(lambda x: x0, d)))